# Copyright 2025 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
import traceback

import numpy as np
import pandas as pd
import torch

from openfold3.core.data.framework.single_datasets.abstract_single import (
    register_dataset,
)
from openfold3.core.data.framework.single_datasets.base_of3 import BaseOF3Dataset
from openfold3.core.data.framework.single_datasets.dataset_utils import (
    pad_to_world_size,
)
from openfold3.core.data.framework.single_datasets.pdb import is_invalid_feature_dict
from openfold3.core.data.primitives.featurization.structure import (
    extract_starts_entities,
)
from openfold3.core.utils.atomize_utils import broadcast_token_feat_to_atoms

logger = logging.getLogger(__name__)


def make_chain_pair_mask_padded(
    all_chains: torch.Tensor, interfaces_to_include: list[tuple[int, int]]
) -> torch.Tensor:
    """Creates a pairwise mask for chains given a list of chain tuples.
    Args:
        all_chains: tensor containing all chain ids in complex
        interfaces_to_include: tuples with pairwise interactions to include
    Returns:
        torch.Tensor [n_chains + 1, n_chains + 1] where:
            - each value [i,j] represents
            - a 0th row and 0th column of all zeros is added as padding
    """
    largest_chain_index = torch.max(all_chains)
    chain_mask = torch.zeros(
        (largest_chain_index + 1, largest_chain_index + 1), dtype=torch.int
    )

    for interface_tuple in interfaces_to_include:
        chain_mask[interface_tuple[0], interface_tuple[1]] = 1
        chain_mask[interface_tuple[1], interface_tuple[0]] = 1

    return chain_mask


@register_dataset
class ValidationPDBDataset(BaseOF3Dataset):
    """Validation Dataset class."""

    def __init__(self, dataset_config: dict, world_size: int | None = None) -> None:
        """Initializes a ValidationDataset.

        Args:
            dataset_config (dict):
                Input config. See openfold3/examples/pdb_sample_dataset_config.yml for
                an example.
        """
        super().__init__(dataset_config)

        self.world_size = world_size

        # Dataset/datapoint cache
        self.create_datapoint_cache()

        # Cropping is turned off
        self.apply_crop = False

    def create_datapoint_cache(self):
        """Creates the datapoint_cache for iterating over each sample.

        Creates a Dataframe storing a flat list of structure_data keys. Used for mapping
        TO the dataset_cache in the getitem. Note that the validation set is not wrapped
        in a StochasticSamplerDataset.
        """
        # Order by token count so that the run times are more consistent across GPUs
        pdb_ids = list(self.dataset_cache.structure_data.keys())

        def null_safe_token_count(x):
            token_count = self.dataset_cache.structure_data[x].token_count
            return token_count if token_count is not None else 0

        pdb_ids = sorted(
            pdb_ids,
            key=null_safe_token_count,
        )
        _datapoint_cache = pd.DataFrame({"pdb_id": pdb_ids})
        self.datapoint_cache = pad_to_world_size(_datapoint_cache, self.world_size)

    def __getitem__(
        self, index: int
    ) -> dict[str : torch.Tensor | dict[str, torch.Tensor]]:
        """Returns a single datapoint from the dataset.

        Note: The data pipeline is modularized at the getitem level to enable
        subclassing for profiling without code duplication. See
        logging_datasets.py for an example."""

        # Get PDB ID from the datapoint cache and the preferred chain/interface
        datapoint = self.datapoint_cache.iloc[index]
        pdb_id = datapoint["pdb_id"]
        is_repeated_sample = bool(datapoint["repeated_sample"])

        if not self.debug_mode:
            sample_data = self.create_all_features(
                pdb_id=pdb_id,
                preferred_chain_or_interface=None,
                return_atom_arrays=True,
                return_crop_strategy=False,
            )
            features = sample_data["features"]
            features["pdb_id"] = pdb_id
            features["preferred_chain_or_interface"] = "none"
            features["repeated_sample"] = torch.tensor(
                [is_repeated_sample], dtype=torch.bool
            )

            return features

        else:
            try:
                sample_data = self.create_all_features(
                    pdb_id=pdb_id,
                    preferred_chain_or_interface=None,
                    return_atom_arrays=True,
                    return_crop_strategy=False,
                )

                features = sample_data["features"]
                features["pdb_id"] = pdb_id
                features["preferred_chain_or_interface"] = "none"

                if is_invalid_feature_dict(features):
                    index = random.randint(0, len(self) - 1)
                    return self.__getitem__(index)

                features["repeated_sample"] = torch.tensor(
                    [is_repeated_sample], dtype=torch.bool
                )

                return features

            except Exception as e:
                tb = traceback.format_exc()
                logger.warning(
                    "-" * 40
                    + "\n"
                    + f"Failed to process ValidationPDBDataset entry {pdb_id}:"
                    + f" {str(e)}\n"
                    + f"Exception type: {type(e).__name__}\nTraceback: {tb}"
                    + "-" * 40
                )
                index = random.randint(0, len(self) - 1)
                return self.__getitem__(index)

    def get_validation_homology_features(self, pdb_id: str, sample_data: dict) -> dict:
        """Create masks for validation metrics analysis.

        Args:
            pdb_id: PDB id for example found in dataset_cache
            sample_data: dictionary containing features for the sample and atom array
        Returns:
            dict with two features:
            - use_for_intra_validation [*, n_tokens]
                mask indicating if token should be used for intrachain metrics
            - use_for_inter_validation [*, n_tokens]
                mask indicating if token should be used for intrachain metrics
        """
        features = {}

        structure_entry = self.dataset_cache.structure_data[pdb_id]

        chains_for_intra_metrics = [
            int(cid)
            for cid, cdata in structure_entry.chains.items()
            if cdata.use_metrics
        ]

        interfaces_to_include = []
        for interface_id, cluster_data in structure_entry.interfaces.items():
            if cluster_data.use_metrics:
                interface_chains = tuple(int(ci) for ci in interface_id.split("_"))
                interfaces_to_include.append(interface_chains)

        # Create token mask for validation intra and inter metrics
        atom_array = sample_data["atom_array"]
        token_starts_with_stop, _ = extract_starts_entities(atom_array)
        token_starts = token_starts_with_stop[:-1]
        token_chain_id = atom_array.chain_id[token_starts].astype(int)

        token_mask = sample_data["features"]["token_mask"]
        num_atoms_per_token = sample_data["features"]["num_atoms_per_token"]

        use_for_intra = torch.tensor(
            np.isin(token_chain_id, chains_for_intra_metrics),
            dtype=torch.int32,
        )
        intra_filter_atomized = broadcast_token_feat_to_atoms(
            token_mask=token_mask,
            num_atoms_per_token=num_atoms_per_token,
            token_feat=use_for_intra,
        ).bool()

        features["intra_filter_atomized"] = intra_filter_atomized

        token_chain_id = torch.tensor(token_chain_id, dtype=torch.int32)
        chain_mask_padded = make_chain_pair_mask_padded(
            token_chain_id, interfaces_to_include
        )

        # [n_token, n_token] for pairwise interactions
        use_for_inter = chain_mask_padded[
            token_chain_id.unsqueeze(0), token_chain_id.unsqueeze(1)
        ]

        # convert use_for_inter: [*, n_token, n_token] into [*, n_atom, n_atom]
        inter_filter_atomized = broadcast_token_feat_to_atoms(
            token_mask=token_mask,
            num_atoms_per_token=num_atoms_per_token,
            token_feat=use_for_inter,
            token_dim=-2,
        )
        inter_filter_atomized = broadcast_token_feat_to_atoms(
            token_mask=token_mask,
            num_atoms_per_token=num_atoms_per_token,
            token_feat=inter_filter_atomized.transpose(-1, -2),
            token_dim=-2,
        )
        inter_filter_atomized = inter_filter_atomized.transpose(-1, -2).bool()
        features["inter_filter_atomized"] = inter_filter_atomized

        return features

    def create_all_features(
        self,
        pdb_id: str,
        preferred_chain_or_interface: str | None,
        return_atom_arrays: bool,
        return_crop_strategy: bool,
    ) -> dict:
        """Calls the parent create_all_features, and then adds features for homology
        similarity."""
        sample_data = super().create_all_features(
            pdb_id,
            preferred_chain_or_interface,
            return_atom_arrays=return_atom_arrays,
            return_crop_strategy=return_crop_strategy,
        )

        validation_homology_filters = self.get_validation_homology_features(
            pdb_id, sample_data
        )
        sample_data["features"]["ground_truth"].update(validation_homology_filters)
        sample_data["features"]["atom_array"] = sample_data["atom_array"]

        # Remove atom arrays if they are not needed
        if not return_atom_arrays:
            del sample_data["atom_array"]
            del sample_data["atom_array_gt"]
            del sample_data["atom_array_cropped"]

        return sample_data
